{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import pickle\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repository root as a relative path; every data path in the cells below is built from it.\n",
    "repo_root = '..'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get MedQA-USMLE dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, download the original MedQA dataset from https://github.com/jind11/MedQA.\n",
    "Put the unzipped folder in `data/medqa_usmle/raw`, so that the question files sit under `data/medqa_usmle/raw/questions/US/4_options/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare `statement` data following CommonsenseQA, OpenBookQA\n",
    "medqa_root = f'{repo_root}/data/medqa_usmle'\n",
    "# os.makedirs works on every platform and does not pass the path through a shell.\n",
    "os.makedirs(f'{medqa_root}/statement', exist_ok=True)\n",
    "\n",
    "for fname in [\"train\", \"dev\", \"test\"]:\n",
    "    # Decode explicitly as UTF-8 so the result does not depend on the platform's default encoding.\n",
    "    with open(f\"{medqa_root}/raw/questions/US/4_options/phrases_no_exclude_{fname}.jsonl\", encoding=\"utf-8\") as f:\n",
    "        lines = f.readlines()\n",
    "    examples = []\n",
    "    for i, raw_line in enumerate(tqdm(lines)):\n",
    "        line = json.loads(raw_line)\n",
    "        # Ids carry the split they belong to; the prefix used to be hard-coded to \"train\" for dev and test too.\n",
    "        _id  = f\"{fname}-{i:05d}\"\n",
    "        answerKey = line[\"answer_idx\"]\n",
    "        stem      = line[\"question\"]\n",
    "        # One entry per option, always in A, B, C, D order.\n",
    "        choices   = [{\"label\": k, \"text\": line[\"options\"][k]} for k in \"ABCD\"]\n",
    "        # A statement is the question stem followed by one candidate answer.\n",
    "        stmts     = [{\"statement\": stem +\" \"+ c[\"text\"]} for c in choices]\n",
    "        ex_obj    = {\"id\": _id, \n",
    "                     \"question\": {\"stem\": stem, \"choices\": choices}, \n",
    "                     \"answerKey\": answerKey, \n",
    "                     \"statements\": stmts\n",
    "                    }\n",
    "        examples.append(ex_obj)\n",
    "    # One JSON object per line.\n",
    "    with open(f\"{medqa_root}/statement/{fname}.statement.jsonl\", 'w') as fout:\n",
    "        for dic in examples:\n",
    "            print (json.dumps(dic), file=fout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Link entities to KG"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, install the scispacy model:\n",
    "```\n",
    "pip install scispacy==0.3.0\n",
    "pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.3.0/en_core_sci_sm-0.3.0.tar.gz\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load scispacy entity linker\n",
    "import spacy\n",
    "import scispacy\n",
    "from scispacy.linking import EntityLinker\n",
    "\n",
    "def load_entity_linker(threshold=0.90):\n",
    "    \"\"\"Create the scispacy pipeline together with its UMLS entity linker.\n",
    "\n",
    "    Loads the `en_core_sci_sm` model, builds an `EntityLinker` named \"umls\" with\n",
    "    abbreviation resolution enabled and the given `threshold`, adds it to the\n",
    "    pipeline and returns `(nlp, linker)`.\n",
    "    \"\"\"\n",
    "    pipeline = spacy.load(\"en_core_sci_sm\")\n",
    "    linker_options = dict(resolve_abbreviations=True, name=\"umls\", threshold=threshold)\n",
    "    umls_linker = EntityLinker(**linker_options)\n",
    "    pipeline.add_pipe(umls_linker)\n",
    "    return pipeline, umls_linker\n",
    "\n",
    "nlp, linker = load_entity_linker()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def entity_linking_to_umls(sentence, nlp, linker):\n",
    "    \"\"\"Run `nlp` on `sentence` and describe every entity mention it finds.\n",
    "\n",
    "    Returns one dict per entity in `doc.ents`, in the same order, holding the\n",
    "    mention text, its `start`/`end` and `start_char`/`end_char` attributes, and\n",
    "    `linking_results`: one dict per candidate in the mention's `_.kb_ents`,\n",
    "    with the canonical name and types looked up in `linker.kb.cui_to_entity`.\n",
    "    \"\"\"\n",
    "    def describe_candidate(candidate):\n",
    "        # A candidate is indexed as (concept id, score).\n",
    "        concept_id, score = candidate[0], candidate[1]\n",
    "        kb_entry = linker.kb.cui_to_entity[concept_id]\n",
    "        return {\"Canonical Name\": kb_entry.canonical_name, \"Concept ID\": concept_id,\n",
    "                \"TUIs\": kb_entry.types, \"Score\": score}\n",
    "\n",
    "    mentions = []\n",
    "    for span in nlp(sentence).ents:\n",
    "        mentions.append({\n",
    "            \"text\": span.text, \"start\": span.start, \"end\": span.end,\n",
    "            \"start_char\": span.start_char, \"end_char\": span.end_char,\n",
    "            \"linking_results\": [describe_candidate(cand) for cand in span._.kb_ents],\n",
    "        })\n",
    "    return mentions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example: link one sample question stem with the pipeline loaded above and display the raw results\n",
    "sent = \"A 5-year-old girl is brought to the emergency department by her mother because of multiple episodes of nausea and vomiting that last about 2 hours. During this period, she has had 6–8 episodes of bilious vomiting and abdominal pain. The vomiting was preceded by fatigue.\"\n",
    "ent_link_results = entity_linking_to_umls(sent, nlp, linker)\n",
    "ent_link_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run entity linking to UMLS for all questions\n",
    "def process(input, nlp=None, linker=None):\n",
    "    \"\"\"Attach UMLS entity-linking results to every statement in `input`.\n",
    "\n",
    "    Args:\n",
    "        input: list of statement dicts as stored in `*.statement.jsonl`; the\n",
    "            dicts are modified in place.\n",
    "        nlp, linker: optional pipeline and linker as returned by\n",
    "            `load_entity_linker()`. When either one is missing a fresh pair is\n",
    "            loaded, which is what this function previously did on every call.\n",
    "\n",
    "    Returns:\n",
    "        The same list, with `stem_ents` added to each `question` and\n",
    "        `text_ents` added to each of its choices.\n",
    "    \"\"\"\n",
    "    if nlp is None or linker is None:\n",
    "        nlp, linker = load_entity_linker()\n",
    "    for stmt in tqdm(input):\n",
    "        question = stmt['question']\n",
    "        # Only the first 3500 characters of the stem are sent to the linker.\n",
    "        question['stem_ents'] = entity_linking_to_umls(question['stem'][:3500], nlp, linker)\n",
    "        for choice in question['choices']:\n",
    "            choice['text_ents'] = entity_linking_to_umls(choice['text'], nlp, linker)\n",
    "    return input\n",
    "\n",
    "for fname in [\"dev\", \"test\", \"train\"]:\n",
    "    with open(f\"{medqa_root}/statement/{fname}.statement.jsonl\") as fin:\n",
    "        stmts = [json.loads(line) for line in fin]\n",
    "    # Reuse the pipeline loaded earlier in the notebook instead of reloading the UMLS linker once per split.\n",
    "    res = process(stmts, nlp, linker)\n",
    "    with open(f\"{medqa_root}/statement/{fname}.statement.umls_linked.jsonl\", 'w') as fout:\n",
    "        for dic in res:\n",
    "            print (json.dumps(dic), file=fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert UMLS entity linking to DDB entity linking (our KG)\n",
    "# umls_to_ddb maps column 2 of the mapping file (looked up by UMLS CUI below) to column 1 (the DDB id).\n",
    "umls_to_ddb = {}\n",
    "with open(f'{repo_root}/data/ddb/ddb_to_umls_cui.txt') as f:\n",
    "    next(f, None)  # the first line is skipped (presumably a header row)\n",
    "    for line in f:\n",
    "        # Strip the newline before splitting so that a CUI in the last column is not stored with a trailing \"\\n\".\n",
    "        elms = line.rstrip(\"\\n\").split(\"\\t\")\n",
    "        if len(elms) < 3:\n",
    "            # Blank or truncated line: nothing to map.\n",
    "            continue\n",
    "        umls_to_ddb[elms[2]] = elms[1]\n",
    "\n",
    "def map_to_ddb(ent_obj):\n",
    "    \"\"\"Translate one linked entity's UMLS candidates into DDB concepts.\n",
    "\n",
    "    Returns a list of `(ddb_cid, canonical_name)` tuples, one for every entry of\n",
    "    `ent_obj['linking_results']` whose `Concept ID` is a key of `umls_to_ddb`.\n",
    "    Candidate order is preserved.\n",
    "    \"\"\"\n",
    "    candidates = ((cand['Concept ID'], cand['Canonical Name']) for cand in ent_obj['linking_results'])\n",
    "    return [(umls_to_ddb[cui], name) for cui, name in candidates if cui in umls_to_ddb]\n",
    "\n",
    "def process(fname):\n",
    "    \"\"\"Write `{fname}.grounded.jsonl` from `{fname}.statement.umls_linked.jsonl`.\n",
    "\n",
    "    For each statement the DDB concepts found in the stem entities become\n",
    "    `qc`/`qc_names`; for each answer choice the concepts found in that choice's\n",
    "    entities become `ac`/`ac_names`. One JSON line is written per\n",
    "    (statement, choice) pair. Note that this reuses the name `process`, so it\n",
    "    replaces the entity-linking `process` defined in the previous cell.\n",
    "    \"\"\"\n",
    "    def ground(ent_objs):\n",
    "        # Flatten the (ddb_cid, name) pairs of all entities into two parallel lists.\n",
    "        cids, names = [], []\n",
    "        for ent_obj in ent_objs:\n",
    "            for ddb_cid, name in map_to_ddb(ent_obj):\n",
    "                cids.append(ddb_cid)\n",
    "                names.append(name)\n",
    "        return cids, names\n",
    "\n",
    "    with open(f\"{medqa_root}/statement/{fname}.statement.umls_linked.jsonl\") as fin:\n",
    "        stmts = [json.loads(line) for line in fin]\n",
    "    with open(f\"{medqa_root}/grounded/{fname}.grounded.jsonl\", 'w') as fout:\n",
    "        for stmt in tqdm(stmts):\n",
    "            question = stmt['question']\n",
    "            sent = question['stem']\n",
    "            qc, qc_names = ground(question['stem_ents'])\n",
    "            for choice in question['choices']:\n",
    "                ans = choice['text']\n",
    "                ac, ac_names = ground(choice['text_ents'])\n",
    "                out = {'sent': sent, 'ans': ans, 'qc': qc, 'qc_names': qc_names, 'ac': ac, 'ac_names': ac_names}\n",
    "                print (json.dumps(out), file=fout)\n",
    "\n",
    "# os.makedirs creates the output directory portably and without going through a shell.\n",
    "os.makedirs(f'{medqa_root}/grounded', exist_ok=True)\n",
    "for fname in [\"dev\", \"test\", \"train\"]:\n",
    "    process(fname)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load knowledge graph (KG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load our KG, which is based on Disease Database + DrugBank."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_ddb():\n",
    "    \"\"\"Load the DDB name and relation tables from `{repo_root}/data/ddb`.\n",
    "\n",
    "    Returns `(relas_lst, ddb_ptr_to_name, ddb_name_to_ptr, ddb_ptr_to_preferred_name)`:\n",
    "    the values of `ddb_relas.json` as a list, a ptr -> [names] mapping, a\n",
    "    name -> ptr mapping, and a ptr -> name mapping restricted to the entries of\n",
    "    `ddb_names.json` whose second field equals \"1\".\n",
    "    \"\"\"\n",
    "    with open(f'{repo_root}/data/ddb/ddb_names.json') as names_file:\n",
    "        all_names = json.load(names_file)\n",
    "    with open(f'{repo_root}/data/ddb/ddb_relas.json') as relas_file:\n",
    "        all_relas = json.load(relas_file)\n",
    "    relas_lst = list(all_relas.values())\n",
    "\n",
    "    ddb_ptr_to_preferred_name = {}\n",
    "    ddb_ptr_to_name = defaultdict(list)\n",
    "    ddb_name_to_ptr = {}\n",
    "    for item_name, fields in all_names.items():\n",
    "        # Each value is indexed as (pointer, preferred flag).\n",
    "        item_ptr, item_preferred = fields[0], fields[1]\n",
    "        if item_preferred == \"1\":\n",
    "            ddb_ptr_to_preferred_name[item_ptr] = item_name\n",
    "        ddb_name_to_ptr[item_name] = item_ptr\n",
    "        ddb_ptr_to_name[item_ptr].append(item_name)\n",
    "\n",
    "    return (relas_lst, ddb_ptr_to_name, ddb_name_to_ptr, ddb_ptr_to_preferred_name)\n",
    "\n",
    "\n",
    "relas_lst, ddb_ptr_to_name, ddb_name_to_ptr, ddb_ptr_to_preferred_name = load_ddb()\n",
    "\n",
    "# Parallel lists: ddb_ptr_lst[i] is the pointer whose preferred name is ddb_names_lst[i].\n",
    "ddb_ptr_lst = list(ddb_ptr_to_preferred_name.keys())\n",
    "ddb_names_lst = list(ddb_ptr_to_preferred_name.values())\n",
    "\n",
    "# Preferred names, one per line.\n",
    "with open(f\"{repo_root}/data/ddb/vocab.txt\", \"w\") as fout:\n",
    "    fout.writelines(f\"{ddb_name}\\n\" for ddb_name in ddb_names_lst)\n",
    "\n",
    "# Matching pointers, one per line, in the same order as vocab.txt.\n",
    "with open(f\"{repo_root}/data/ddb/ptrs.txt\", \"w\") as fout:\n",
    "    fout.writelines(f\"{ddb_ptr}\\n\" for ddb_ptr in ddb_ptr_lst)\n",
    "\n",
    "# A concept's integer id is its position in this list.\n",
    "id2concept = ddb_ptr_lst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: number of pointers with any name, pointers with a preferred name, and distinct names\n",
    "len(ddb_ptr_to_name), len(ddb_ptr_to_preferred_name), len(ddb_name_to_ptr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the name -> pointer lookup for two names\n",
    "ddb_name_to_ptr['Ethanol'], ddb_name_to_ptr['Serine']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation names. `construct_graph` uses this list as `id2relation`, so a relation's id is its position here.\n",
    "merged_relations = [\n",
    "    'belongs_to_the_category_of',\n",
    "    'is_a_category',\n",
    "    'may_cause',\n",
    "    'is_a_subtype_of',\n",
    "    'is_a_risk_factor_of',\n",
    "    'is_associated_with',\n",
    "    'may_contraindicate',\n",
    "    'interacts_with',\n",
    "    'belongs_to_the_drug_family_of',\n",
    "    'belongs_to_drug_super-family',\n",
    "    'is_a_vector_for',\n",
    "    'may_be_allelic_with',\n",
    "    'see_also',\n",
    "    'is_an_ingradient_of',\n",
    "    'may_treat'\n",
    "]\n",
    "\n",
    "# Maps the relation code stored in each `relas_lst` entry (`relation[2]`) to the relation id used on graph edges.\n",
    "# NOTE(review): the 15 ids presumably line up with the 15 names in `merged_relations` above; confirm.\n",
    "relas_dict = {\"0\": 0, \"1\": 1, \"2\": 2, \"3\": 3, \"4\": 4, \"6\": 5, \"10\": 6, \"12\": 7, \"16\": 8, \"17\": 9, \"18\": 10,\n",
    "             \"20\": 11, \"26\": 12, \"30\": 13, \"233\": 14}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "\n",
    "def construct_graph():\n",
    "    \"\"\"Build the DDB knowledge graph and pickle it to `{repo_root}/data/ddb/ddb.graph`.\n",
    "\n",
    "    Every entry of `relas_lst` is read as (subject ptr, object ptr, relation\n",
    "    code). It adds a forward edge carrying the relation id and a reverse edge\n",
    "    whose id is offset by the number of relations; every edge has weight 1.0.\n",
    "\n",
    "    Returns:\n",
    "        `(concept2id, id2relation, relation2id, graph)`, where `graph` is an\n",
    "        `nx.MultiDiGraph` over integer concept ids.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a relation refers to a pointer missing from `id2concept`\n",
    "            or to a relation code missing from `relas_dict`.\n",
    "    \"\"\"\n",
    "    concept2id = {w: i for i, w in enumerate(id2concept)}\n",
    "    id2relation = merged_relations\n",
    "    relation2id = {r: i for i, r in enumerate(id2relation)}\n",
    "    n_rel = len(relation2id)  # offset that turns a relation id into its reverse-direction id\n",
    "    graph = nx.MultiDiGraph()\n",
    "    for relation in relas_lst:\n",
    "        subj = concept2id[relation[0]]\n",
    "        obj = concept2id[relation[1]]\n",
    "        rel = relas_dict[relation[2]]\n",
    "        weight = 1.\n",
    "        graph.add_edge(subj, obj, rel=rel, weight=weight)\n",
    "        graph.add_edge(obj, subj, rel=rel + n_rel, weight=weight)\n",
    "    output_path = f\"{repo_root}/data/ddb/ddb.graph\"\n",
    "    # nx.write_gpickle was removed in NetworkX 3.0. It pickled the graph with the\n",
    "    # highest protocol, so writing the pickle directly produces the same file on every version.\n",
    "    with open(output_path, 'wb') as fout:\n",
    "        pickle.dump(graph, fout, pickle.HIGHEST_PROTOCOL)\n",
    "    return concept2id, id2relation, relation2id, graph\n",
    "\n",
    "concept2id, id2relation, relation2id, KG = construct_graph()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get KG subgraph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We get KG subgraph for each question, following the method used for CommonsenseQA + ConceptNet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_kg():\n",
    "    \"\"\"Publish the KG as the globals `cpnet` and `cpnet_simple`.\n",
    "\n",
    "    `cpnet` is the multigraph `KG`. `cpnet_simple` is an undirected `nx.Graph`\n",
    "    in which all edges between a pair of nodes are merged into a single edge\n",
    "    whose weight is the sum of the merged weights (1.0 is used for an edge\n",
    "    that has no `weight`).\n",
    "    \"\"\"\n",
    "    global cpnet, cpnet_simple\n",
    "    cpnet = KG\n",
    "    cpnet_simple = nx.Graph()\n",
    "    for src, dst, attrs in cpnet.edges(data=True):\n",
    "        edge_weight = attrs.get('weight', 1.0)\n",
    "        if not cpnet_simple.has_edge(src, dst):\n",
    "            cpnet_simple.add_edge(src, dst, weight=edge_weight)\n",
    "            continue\n",
    "        cpnet_simple[src][dst]['weight'] += edge_weight\n",
    "\n",
    "load_kg()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import csr_matrix, coo_matrix\n",
    "from multiprocessing import Pool\n",
    "\n",
    "def concepts2adj(node_ids):\n",
    "    \"\"\"Build the per-relation adjacency among `node_ids` from the global `cpnet`.\n",
    "\n",
    "    Returns `(adj, cids)`. `cids` is `node_ids` as an int32 array. `adj` is a\n",
    "    `coo_matrix` of shape `(n_rel * n_node, n_node)`: the `(n_rel, n_node, n_node)`\n",
    "    uint8 tensor reshaped to two dimensions, where entry `[r, s, t]` is 1 when\n",
    "    `cpnet` has an edge `cids[s] -> cids[t]` with `rel == r`. Edges whose `rel`\n",
    "    falls outside `[0, n_rel)` are ignored; `n_rel` is `len(id2relation)`.\n",
    "    \"\"\"\n",
    "    global id2relation\n",
    "    cids = np.array(node_ids, dtype=np.int32)\n",
    "    n_rel, n_node = len(id2relation), cids.shape[0]\n",
    "    adj = np.zeros((n_rel, n_node, n_node), dtype=np.uint8)\n",
    "    for s, s_c in enumerate(cids):\n",
    "        for t, t_c in enumerate(cids):\n",
    "            if not cpnet.has_edge(s_c, t_c):\n",
    "                continue\n",
    "            # A multigraph can hold several parallel edges; mark each relation that occurs.\n",
    "            for e_attr in cpnet[s_c][t_c].values():\n",
    "                rel = e_attr['rel']\n",
    "                if 0 <= rel < n_rel:\n",
    "                    adj[rel][s][t] = 1\n",
    "    return coo_matrix(adj.reshape(-1, n_node)), cids\n",
    "\n",
    "def concepts_to_adj_matrices_2hop_all_pair(data):\n",
    "    \"\"\"Build the subgraph for one (question concepts, answer concepts) pair.\n",
    "\n",
    "    Args:\n",
    "        data: tuple `(qc_ids, ac_ids)` of concept-id collections.\n",
    "\n",
    "    Returns:\n",
    "        dict with `adj` and `concepts` as returned by `concepts2adj`, boolean\n",
    "        arrays `qmask` and `amask` marking the question and answer positions in\n",
    "        `concepts`, and `cid2score` set to None. Nodes are ordered as sorted\n",
    "        question ids, then sorted answer ids, then the sorted extra nodes. An\n",
    "        extra node is any node of `cpnet_simple` adjacent to two different\n",
    "        question/answer nodes and not itself one of them.\n",
    "    \"\"\"\n",
    "    from itertools import combinations\n",
    "\n",
    "    qc_ids, ac_ids = data\n",
    "    qa_nodes = set(qc_ids) | set(ac_ids)\n",
    "    # Build each neighbour set once. Previously both sets were rebuilt for every\n",
    "    # ordered pair, although the intersection is symmetric.\n",
    "    neighbour_sets = [set(cpnet_simple[node]) for node in qa_nodes if node in cpnet_simple.nodes]\n",
    "    extra_nodes = set()\n",
    "    for first, second in combinations(neighbour_sets, 2):\n",
    "        extra_nodes |= first & second\n",
    "    extra_nodes = extra_nodes - qa_nodes\n",
    "    schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes)\n",
    "    arange = np.arange(len(schema_graph))\n",
    "    qmask = arange < len(qc_ids)\n",
    "    amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))\n",
    "    adj, concepts = concepts2adj(schema_graph)\n",
    "    return {'adj': adj, 'concepts': concepts, 'qmask': qmask, 'amask': amask, 'cid2score': None}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_adj_data_from_grounded_concepts(grounded_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes):\n",
    "    \"\"\"Build one KG subgraph per line of `grounded_path` and pickle the list to `output_path`.\n",
    "\n",
    "    Each grounded line supplies DDB concept ids under `qc` and `ac`, which are\n",
    "    translated with `concept2id`. An empty question set falls back to concept\n",
    "    '31770' and an empty answer set to concept '325'; answer concepts are then\n",
    "    removed from the question set. The subgraphs are built by\n",
    "    `concepts_to_adj_matrices_2hop_all_pair` in a pool of `num_processes`\n",
    "    workers. `cpnet_graph_path` and `cpnet_vocab_path` are accepted but not\n",
    "    read by this function.\n",
    "    \"\"\"\n",
    "    global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet\n",
    "\n",
    "    qa_data = []\n",
    "    with open(grounded_path, 'r', encoding='utf-8') as fin:\n",
    "        for record in map(json.loads, fin):\n",
    "            # `or` evaluates the fallback only when the grounded set is empty.\n",
    "            q_ids = {concept2id[c] for c in record['qc']} or {concept2id['31770']}\n",
    "            a_ids = {concept2id[c] for c in record['ac']} or {concept2id['325']}\n",
    "            qa_data.append((q_ids - a_ids, a_ids))\n",
    "\n",
    "    with Pool(num_processes) as workers:\n",
    "        subgraphs = workers.imap(concepts_to_adj_matrices_2hop_all_pair, qa_data)\n",
    "        res = list(tqdm(subgraphs, total=len(qa_data)))\n",
    "\n",
    "    # Summary of subgraph sizes (number of nodes per example).\n",
    "    lens = [len(e['concepts']) for e in res]\n",
    "    mean_nodes, med_nodes = int(np.mean(lens)), int(np.median(lens))\n",
    "    pct5, pct95 = int(np.percentile(lens, 5)), int(np.percentile(lens, 95))\n",
    "    print ('mean #nodes', mean_nodes, 'med', med_nodes, '5th', pct5, '95th', pct95)\n",
    "\n",
    "    with open(output_path, 'wb') as fout:\n",
    "        pickle.dump(res, fout)\n",
    "\n",
    "    print(f'adj data saved to {output_path}')\n",
    "    print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# os.makedirs creates the output directory portably and without going through a shell.\n",
    "os.makedirs(f'{repo_root}/data/medqa_usmle/graph', exist_ok=True)\n",
    "\n",
    "num_processes = 10  # worker processes used to build the subgraphs\n",
    "\n",
    "for fname in [\"dev\", \"test\", \"train\"]:\n",
    "    grounded_path = f\"{repo_root}/data/medqa_usmle/grounded/{fname}.grounded.jsonl\"\n",
    "    # NOTE(review): generate_adj_data_from_grounded_concepts does not read the two KG paths below;\n",
    "    # it works on the in-memory KG. Also, this notebook writes `ptrs.txt`, not `ddb_ptrs.txt`; confirm which name is intended.\n",
    "    kg_path       = f\"{repo_root}/data/ddb/ddb.graph\"\n",
    "    kg_vocab_path = f\"{repo_root}/data/ddb/ddb_ptrs.txt\"\n",
    "    output_path   = f\"{repo_root}/data/medqa_usmle/graph/{fname}.graph.adj.pk\"\n",
    "\n",
    "    generate_adj_data_from_grounded_concepts(grounded_path, kg_path, kg_vocab_path, output_path, num_processes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get KG entity embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModel, AutoConfig\n",
    "# Load the SapBERT tokenizer and encoder used to embed the KG entity names.\n",
    "tokenizer  = AutoTokenizer.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")\n",
    "# The closing quote was doubled here, which made the whole cell a SyntaxError.\n",
    "bert_model = AutoModel.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")\n",
    "# Fall back to the CPU when no CUDA device is available instead of failing on `.to(device)`.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "bert_model.to(device)\n",
    "bert_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the entity names back from vocab.txt, one per line and in file order, with surrounding whitespace stripped\n",
    "with open(f\"{repo_root}/data/ddb/vocab.txt\") as f:\n",
    "    names = [line.strip() for line in f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize all names together (padded to a common length), then encode them one name per forward pass.\n",
    "tensors = tokenizer(names, padding=True, truncation=True, return_tensors=\"pt\")\n",
    "input_ids, attention_mask = tensors[\"input_ids\"], tensors['attention_mask']\n",
    "embs = []\n",
    "with torch.no_grad():\n",
    "    for i in tqdm(range(len(names))):\n",
    "        row = slice(i, i + 1)\n",
    "        outputs = bert_model(input_ids=input_ids[row].to(device),\n",
    "                             attention_mask=attention_mask[row].to(device))\n",
    "        # outputs[1] is kept as the embedding (presumably the pooled output; confirm for this model).\n",
    "        vector = outputs[1].squeeze().tolist()\n",
    "        embs.append(np.array(vector).reshape((1, -1)))\n",
    "# Stack the per-name rows into one array, row i belonging to names[i], and save it.\n",
    "embs = np.concatenate(embs)\n",
    "np.save(f\"{repo_root}/data/ddb/ent_emb.npy\", embs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ct",
   "language": "python",
   "name": "ct"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
